In this document, we illustrate R functions for Bayesian marginal additive models (BMAM).

Brief overview of Bayesian marginal additive models

Heagerty and Zeger (2000) presented marginal models, in which the marginal mean, rather than the conditional mean given random effects, is regressed on covariates. Marginal models are useful when we are interested in associations averaged across a population of clusters. When the form of the covariate-outcome association is unknown, additive models can capture potentially non-linear relationships. McGee and Stringer (2022) proposed the marginal additive model (MAM) for modeling dependent data with non-linear population-averaged associations.

We propose the Bayesian marginal additive model (BMAM), which extends the marginal additive model (MAM) to a Bayesian context.

Bayesian conditional models

We consider a Generalized Additive Mixed Model (GAMM), \[ g\left\{\mu^{\mathrm{C}}\left(\boldsymbol{x}_{i j}, \boldsymbol{z}_{i j} \mid \boldsymbol{u}_{i}\right)\right\}=\sum_{l=1}^{p} f_{l}^{\mathrm{C}}\left(x_{i j l}\right)+\boldsymbol{z}_{i j}^{\mathrm{T}} \boldsymbol{u}_{i} \] where \(i = 1,\cdots,N\) indexes clusters and \(j = 1, \cdots, n_i\) indexes units within clusters. Random effects \(\boldsymbol{u}_{i}\) follow a normal distribution with mean 0 and covariance matrix \(\boldsymbol{\Sigma}=\boldsymbol{\Lambda}\left(\boldsymbol{\sigma}\right) \boldsymbol{\Omega} \boldsymbol{\Lambda}\left(\boldsymbol{\sigma}\right)\), where \(\boldsymbol{\Lambda}\left(\boldsymbol{\sigma}\right)\) denotes the diagonal matrix with diagonal elements \(\boldsymbol{\sigma}\).

\(f_{l}^{\mathrm{C}}(x)=\sum_{q=1}^{d_{l}} b_{l q}^{\mathrm{C}}(x) \alpha_{l q}^{\mathrm{C}}\) are unknown smooth functions. Each \(\boldsymbol{\alpha}_{l}^{\mathrm{C}}\) is associated with a penalty \(\mathcal{P}_{\tau_{l}^{\mathrm{C}}}\left(\boldsymbol{\alpha}_{l}^{\mathrm{C}} ; \boldsymbol{S}_{l}\right) \propto \mathrm{exp} \left( - \tau_{l} \boldsymbol{\alpha}_{l}^{\mathrm{C}^\mathrm{T}} \boldsymbol{S}_{l} \boldsymbol{\alpha}_{l}^{\mathrm{C}} \right)\).

According to Wood (2006), \(\sum_{q=1}^{d_{l}} b_{l q}^{\mathrm{C}}(x) \alpha_{l q}^{\mathrm{C}}\) can be rewritten as \(\sum_{q=1}^{d_{l}} b_{l q}^{*\mathrm{C}}(x) \alpha_{l q}^{*\mathrm{C}}\). The penalty can then be interpreted as a prior distribution, \[ \alpha_{l k}^{*\mathrm{C}} \sim \mathrm{N}\left(0,\frac{1}{\tau_{l}}\right), \quad k=3,\cdots,d_{l}, \] where \(\left[ b_{l 1}^{*\mathrm{C}}(x), b_{l 2}^{*\mathrm{C}}(x), \cdots, b_{l d_{l}}^{*\mathrm{C}}(x)\right] = \left[ b_{l 1}^{\mathrm{C}}(x), b_{l 2}^{\mathrm{C}}(x), \cdots, b_{l d_{l}}^{\mathrm{C}}(x)\right] \mathbf{D}_{l+}^{-1}\), \(\mathbf{D}_{l+} = \left[\begin{array}{c} \mathbf{I}_{2} \;\; \mathbf{0} \\ \mathbf{D}_l \end{array}\right]\), and \(\mathbf{D}_l\) is a \((d_{l}-2)\times d_{l}\) matrix with \(\mathbf{D}_l^\mathrm{T} \mathbf{D}_l = \boldsymbol{S}_l\).

Thus, the GAMM can be rewritten as,

\[ \begin{aligned} g\left\{\mu^{\mathrm{C}}\left(\boldsymbol{x}_{i j}, \boldsymbol{z}_{i j} \mid \boldsymbol{u}_{i}\right)\right\}&=\sum_{l=1}^{p} \sum_{q=1}^{d_{l}} b_{l q}^{*\mathrm{C}}\left(x_{i j l}\right) \alpha_{l q}^{*\mathrm{C}}+\boldsymbol{z}_{i j}^{\mathrm{T}} \boldsymbol{u}_{i},\\ \boldsymbol{u}_{i} &\sim \mathrm{N}(0,\boldsymbol{\Lambda}\left(\boldsymbol{\sigma}\right) \boldsymbol{\Omega} \boldsymbol{\Lambda}\left(\boldsymbol{\sigma}\right)),\\ \alpha_{l k}^{*\mathrm{C}} &\sim \mathrm{N}\left(0,\frac{1}{\tau_{l}}\right), \quad k=3,\cdots,d_{l}. \end{aligned} \] In addition to \(\boldsymbol{u}_{i}\) and \(\alpha_{l k}^{*\mathrm{C}}, k=3,\cdots,d_{l}\), a Bayesian model requires prior distributions for \(\alpha_{l 1}^{*\mathrm{C}}, \alpha_{l 2}^{*\mathrm{C}}\), \(\tau_l\), \(\boldsymbol{\sigma}\) and \(\boldsymbol{\Omega}\).

Our bmam functions are built on the brms package, and prior distributions can be set through the brm function. By default, \(\alpha_{l 1}^{*\mathrm{C}}\) and \(\alpha_{l 2}^{*\mathrm{C}}\) have improper flat priors, and the smoothing standard deviation \(\sqrt{1/\tau_l}\) has a half Student-t prior with 3 degrees of freedom. The standard deviations \(\boldsymbol{\sigma}\) in the prior distribution of \(\boldsymbol{u}_{i}\) also have half Student-t priors with 3 degrees of freedom, and \(\boldsymbol{\Omega} \sim \operatorname{LKJ}(1)\).
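These defaults can be overridden through brms's set_prior interface before fitting. The following sketch shows the general pattern; the specific prior choices here are illustrative assumptions, not recommendations.

```r
library(brms)

# Illustrative (assumed) priors; the classes "b", "sds", "sd" and "cor"
# are brms's labels for the unpenalized coefficients and fixed effects,
# the smooth standard deviations sqrt(1/tau_l), the random-effect
# standard deviations sigma, and the correlation matrix Omega.
priors <- c(
  set_prior("normal(0, 5)",         class = "b"),
  set_prior("student_t(3, 0, 2.5)", class = "sds"),
  set_prior("student_t(3, 0, 2.5)", class = "sd"),
  set_prior("lkj(2)",               class = "cor")
)
# Pass `prior = priors` to the brm() call used for model fitting.
```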

Bayesian marginal additive models

The Bayesian marginal additive model (BMAM) is, \[ g\left\{\mu^{\mathrm{M}}\left(\boldsymbol{x}_{i j}\right)\right\}=f^{\mathrm{M}}\left(\boldsymbol{x}_{i j}\right) \] where \(f^{\mathrm{M}}\left(\boldsymbol{x}_{i j}\right) = \sum_{l=1}^{p} f_{l}^{\mathrm{M}}\left(x_{i j l}\right) = \sum_{l=1}^{p}\sum_{q=1}^{Q_{l}} b_{l q}^{\mathrm{M}}\left(x_{i j l}\right) \alpha_{l q}^{\mathrm{M}}\). Inference on the parameters \(\alpha_{l q}^{\mathrm{M}}\) is based on the Bayesian conditional model through post-hoc marginalization (Hedeker et al. 2018). More details on the relationship between marginal additive models and conditional models can be found in McGee and Stringer (2022).

Specifically, we draw posterior samples from a fitted Bayesian conditional model and, for each posterior sample, obtain an estimate of \(\alpha_{lq}^{\mathrm{M}}\) through post-hoc marginalization. If \(H\) samples are drawn from the posterior distribution of the Bayesian conditional model, we obtain \(H\) estimates of \(\alpha_{lq}^{\mathrm{M}}\), on which inference is based.
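Concretely, the marginal mean is the conditional mean averaged over the random-effects distribution. For posterior draw \(h\) (the superscript \((h)\) is our notation here), using \(k\) Monte Carlo draws \(\boldsymbol{u}^{(1)}, \cdots, \boldsymbol{u}^{(k)}\) from \(\mathrm{N}(\boldsymbol{0}, \boldsymbol{\Sigma}^{(h)})\), a sketch of the marginalization step is

\[ \mu^{\mathrm{M}(h)}\left(\boldsymbol{x}_{i j}\right)=\mathrm{E}_{\boldsymbol{u}}\left[\mu^{\mathrm{C}(h)}\left(\boldsymbol{x}_{i j}, \boldsymbol{z}_{i j} \mid \boldsymbol{u}\right)\right] \approx \frac{1}{k} \sum_{r=1}^{k} g^{-1}\left\{\sum_{l=1}^{p} f_{l}^{\mathrm{C}(h)}\left(x_{i j l}\right)+\boldsymbol{z}_{i j}^{\mathrm{T}} \boldsymbol{u}^{(r)}\right\}, \]

and the \(h\)-th estimate of \(\alpha_{l q}^{\mathrm{M}}\) is then recovered by expressing \(g\{\mu^{\mathrm{M}(h)}(\cdot)\}\) in the marginal basis \(b_{l q}^{\mathrm{M}}(\cdot)\).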

Basic example of Bayesian marginal additive models

We use the function SimData to generate data illustrating the main features of our R functions for BMAM. The data are generated according to the set-up in McGee and Stringer (2022). We consider the binary case, where the outcome variable \(Y\) follows a Bernoulli distribution, and the link function \(g(\cdot)\) is the logit function.

We generate three covariates \(\boldsymbol{x}_{i j}=\left(x_{i j 1}, x_{i j 2}, x_{i j 3}\right)^{\mathrm{T}}\) where \(x_{i j l} \sim \operatorname{Unif}(-1,1)\) for \(l=1,2,3\). Given the forms of the smooth functions \(f_{1}^{\mathrm{M}}(\cdot)\) and \(f_{2}^{\mathrm{M}}(\cdot)\), we generate \(Y_{ij}\) according to the marginal model,

\[ \operatorname{logit}\left\{\mu^{\mathrm{M}}\left(\boldsymbol{x}_{i j}\right)\right\}=f_{1}^{\mathrm{M}}\left(x_{i j 1}\right)+f_{2}^{\mathrm{M}}\left(x_{i j 2}\right)+\beta_{3} \cdot x_{i j 3}. \] In addition, we impose the following random-intercept and random-slope dependence structure,

\[ \begin{aligned} \operatorname{logit}\left\{\mu^{\mathrm{C}}\left(\boldsymbol{x}_{i j} \mid \boldsymbol{u}_{i}\right)\right\} &=\Delta\left(\boldsymbol{x}_{i j}\right)+u_{i 0}+u_{i 1} \cdot x_{i j 3} \\ {\left[\begin{array}{l} u_{i 0} \\ u_{i 1} \end{array}\right] } & \sim \operatorname{MVN}\left(\left[\begin{array}{l} 0 \\ 0 \end{array}\right],\left[\begin{array}{cc} 4 & 1 \\ 1 & 1 \end{array}\right]\right) \end{aligned} \]

Here we set \(\beta_3 = 0\), the number of clusters \(N = 100\), and the number of units within each cluster \(n_i = 10, i = 1, \ldots, N\).

set.seed(4321)
simdata <- SimData(K = 100, Nk = 10)
dat <- simdata$data
fun <- simdata$f

knitr::kable(head(dat), digits = 4)
id intercepts slopes x1 x2 x3 fx1 fx2 y
1 0.0042 0.7361 -0.330 0.961 0.170 -0.6311 -1.6801 0
1 0.0042 0.7361 0.818 -0.099 0.043 -0.8243 -0.0387 0
1 0.0042 0.7361 -0.177 -0.329 0.236 -0.4378 -0.3883 1
1 0.0042 0.7361 -0.912 -0.377 -0.130 1.4146 -0.4935 1
1 0.0042 0.7361 0.527 0.715 0.582 0.4406 -1.2772 0
1 0.0042 0.7361 0.501 0.095 0.521 0.4980 -0.0357 0

Here simdata$data is the generated dataset, and simdata$f contains the smooth functions \(f_{1}^{\mathrm{M}}(\cdot)\) and \(f_{2}^{\mathrm{M}}(\cdot)\).

x <- seq(-1, 1, length.out = 1000)
f1 <- fun[[1]](x)
f2 <- fun[[2]](x)
par(mfrow = c(1, 2))
plot(x, f1, type = "l")
plot(x, f2, type = "l")

Fit BMAM

The BMAM is based on a Bayesian GAMM, which can be fitted by the brm function in the brms package.

library(brms)
model_brms <- brm(bf(y ~  x3 + s(x1) + s(x2) + (1+x3|id)),
                  data = dat, family = "bernoulli", 
                  cores = 4, seed = 4321,
                  warmup = 1000, iter = 2000, chains = 4, 
                  refresh=0, backend = "cmdstanr")
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 3 finished in 25.6 seconds.
#> Chain 4 finished in 25.8 seconds.
#> Chain 1 finished in 26.5 seconds.
#> Chain 2 finished in 27.1 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 26.2 seconds.
#> Total execution time: 27.4 seconds.

The settings of the brm function are described in Bürkner (2017). We use bf() to specify a formula containing smooth terms and random effects. The GAMM is specified as a logistic model, and we sample 4 chains of 2000 iterations each (the first 1000 iterations are warm-up). Note that our BMAM functions only support the cmdstanr backend.

Then, we draw trace plots to check the convergence of the Hamiltonian Monte Carlo sampler in brms, along with density plots of the posterior samples.

plot(model_brms)

The BMAM functions are built on the brmsmargins package. We first load the required packages,

library(brmsmargins)
library(data.table)
library(dplyr)
library(ggplot2)

To fit the BMAM, we use the bmam function.

bmam.fit <- bmam(object = model_brms, k = 100, CIType = "ETI", CI = 0.95)

A Monte Carlo method is used to integrate out the random effects; the number of random draws is set by k = 100 (the default). The type of credible interval (CI) can be specified by CIType; in this example, we use the equal-tailed interval, ETI (the default). See the package documentation for the other CIType options supported by this function. The probability of the CI is chosen as 95%.
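To illustrate what this integration does, here is a minimal base-R sketch (not using bmam itself) for a logistic model with a random intercept: the marginal probability at a given covariate value is the average of the conditional probabilities over draws of the random effect. The values of sigma_u and eta below are assumptions chosen purely for illustration.

```r
# Monte Carlo marginalization over a random intercept (illustration only;
# bmam performs this step internally for every posterior draw).
set.seed(4321)
k <- 10000                     # number of Monte Carlo draws of the random effect
sigma_u <- 2                   # assumed random-intercept standard deviation
eta <- 0.5                     # conditional linear predictor at some covariate value
u <- rnorm(k, mean = 0, sd = sigma_u)

mu_conditional <- plogis(eta)          # conditional probability at u = 0
mu_marginal <- mean(plogis(eta + u))   # Monte Carlo estimate of the marginal probability

c(conditional = mu_conditional, marginal = mu_marginal)
```

Because the logit link is nonlinear, the marginal probability is attenuated toward 0.5 relative to the conditional probability at \(u = 0\); the same attenuation is visible in the summary output below, where the marginal intercept estimate is smaller in magnitude than the conditional one.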

Summarize model

Let’s now summarize the fitted BMAM.

The summary function returns a list of data frames. Each data frame contains a summary of the posterior distributions of the coefficients of the linear terms, e.g. the Intercept and \(x_3\) in this example.

The first data frame contains the estimates from the marginal model (BMAM), and the second the estimates from the conditional model (Bayesian GAMM).

bmam.summary <- summary(bmam.fit)
#> 
#> Family: bernoulli 
#> Link function: logit 
#> 
#> y ~ x3 + s(x1) + s(x2) + (1 + x3 | id) 
#> 
#>  Marginal Model 
#>    Parameter       M     Mdn     LL     UL   CI CIType
#> 1: Intercept -0.6994 -0.6965 -1.082 -0.332 0.95    ETI
#> 2:        x3  0.0115  0.0109 -0.218  0.242 0.95    ETI
#> 
#>  Conditional Model 
#>    Parameter       M      Mdn     LL     UL   CI CIType
#> 1: Intercept -1.0624 -1.05754 -1.503 -0.646 0.95    ETI
#> 2:        x3  0.0111  0.00838 -0.332  0.358 0.95    ETI

For each parameter, the returned information includes:

  • M: the mean of the posterior samples
  • Mdn: the median of the posterior samples
  • LL: the lower limit of the credible interval
  • UL: the upper limit of the credible interval
  • CI: the probability of credible interval
  • CIType: the type of credible interval

If we are also interested in the smooth functions, we can set plot.smooth to TRUE to plot the estimated smooth functions,

bmam.summary <- summary(bmam.fit, plot.smooth = TRUE)
#> 
#> Family: bernoulli 
#> Link function: logit 
#> 
#> y ~ x3 + s(x1) + s(x2) + (1 + x3 | id) 
#> 
#>  Marginal Model 
#>    Parameter       M     Mdn     LL     UL   CI CIType
#> 1: Intercept -0.6994 -0.6965 -1.082 -0.332 0.95    ETI
#> 2:        x3  0.0115  0.0109 -0.218  0.242 0.95    ETI
#> 
#>  Conditional Model 
#>    Parameter       M      Mdn     LL     UL   CI CIType
#> 1: Intercept -1.0624 -1.05754 -1.503 -0.646 0.95    ETI
#> 2:        x3  0.0111  0.00838 -0.332  0.358 0.95    ETI

In the figures, we display the point estimates (means of the posterior samples) and 95% CIs (ETI). The red curves denote the estimates from the BMAM; the blue curves denote the estimates from the Bayesian GAMM (conditional model).

The summary function also provides summaries of the coefficients \(\alpha_l^{\mathrm{M}}\) of the smooth terms in the BMAM and \(\alpha_l^{\mathrm{C}}\) in the Bayesian GAMM.

The default dimension of the coefficients \(\alpha_l^{\mathrm{M}}\) (or \(\alpha_l^{\mathrm{C}}\)) is 10, i.e. \(Q_{1} = Q_{2} = 10\) in this example. The first coefficient \(\alpha_{l 1}^{\mathrm{M}}\) (or \(\alpha_{l 1}^{\mathrm{C}}\)) is the intercept, so we only report the coefficients \(\alpha_{l q}^{\mathrm{M}}, \alpha_{l q}^{\mathrm{C}}, q=2,3,\cdots,Q_l\).

The coefficients of the smooth terms in the BMAM,

knitr::kable(bmam.summary$BMAM$Smooth[[1]], digits = 4)
Parameter M Mdn LL UL CI CIType
Xs_1_alpha -23.929 -23.3924 -38.90 -12.81 0.95 ETI
Zs_1_1_alpha_1 -1.599 -1.2577 -10.65 5.50 0.95 ETI
Zs_1_1_alpha_2 -0.719 -0.7923 -4.89 3.65 0.95 ETI
Zs_1_1_alpha_3 8.046 7.9326 5.16 11.59 0.95 ETI
Zs_1_1_alpha_4 -0.102 -0.0951 -2.35 2.12 0.95 ETI
Zs_1_1_alpha_5 -0.442 -0.5163 -5.36 4.73 0.95 ETI
Zs_1_1_alpha_6 -0.706 -0.6397 -5.86 3.73 0.95 ETI
Zs_1_1_alpha_7 -2.127 -1.8284 -9.69 3.95 0.95 ETI
Zs_1_1_alpha_8 -4.862 -4.3483 -14.45 1.82 0.95 ETI
knitr::kable(bmam.summary$BMAM$Smooth[[2]], digits = 4)
Parameter M Mdn LL UL CI CIType
Xs_2_alpha -0.9211 -0.9319 -9.41 7.45 0.95 ETI
Zs_2_1_alpha_1 0.2962 0.2211 -4.93 5.69 0.95 ETI
Zs_2_1_alpha_2 -1.0763 -0.9099 -5.33 2.29 0.95 ETI
Zs_2_1_alpha_3 0.1172 0.1285 -2.17 2.38 0.95 ETI
Zs_2_1_alpha_4 -4.7807 -4.6938 -7.56 -2.48 0.95 ETI
Zs_2_1_alpha_5 -0.0794 -0.0918 -4.14 4.12 0.95 ETI
Zs_2_1_alpha_6 0.1061 0.0818 -3.96 4.24 0.95 ETI
Zs_2_1_alpha_7 0.0845 0.0804 -5.07 5.15 0.95 ETI
Zs_2_1_alpha_8 0.2724 0.2286 -4.77 5.53 0.95 ETI

The coefficients of the smooth terms in the Bayesian GAMM (conditional model),

knitr::kable(bmam.summary$Conditional_Model$Smooth[[1]], digits = 4)
Parameter M Mdn LL UL CI CIType
Xs_1_alpha -34.419 -33.724 -54.71 -18.65 0.95 ETI
Zs_1_1_alpha_1 -1.586 -1.203 -14.91 9.10 0.95 ETI
Zs_1_1_alpha_2 -1.614 -1.703 -7.57 4.44 0.95 ETI
Zs_1_1_alpha_3 11.741 11.664 7.68 16.62 0.95 ETI
Zs_1_1_alpha_4 -0.403 -0.388 -3.62 2.76 0.95 ETI
Zs_1_1_alpha_5 -0.592 -0.696 -7.83 7.18 0.95 ETI
Zs_1_1_alpha_6 -0.499 -0.407 -7.83 6.04 0.95 ETI
Zs_1_1_alpha_7 -2.553 -2.141 -13.67 6.69 0.95 ETI
Zs_1_1_alpha_8 -7.081 -6.349 -21.53 2.75 0.95 ETI
knitr::kable(bmam.summary$Conditional_Model$Smooth[[2]], digits = 4)
Parameter M Mdn LL UL CI CIType
Xs_2_alpha -1.4728 -1.4919 -13.55 10.71 0.95 ETI
Zs_2_1_alpha_1 0.3013 0.2706 -7.82 8.41 0.95 ETI
Zs_2_1_alpha_2 -1.6142 -1.3643 -7.79 3.19 0.95 ETI
Zs_2_1_alpha_3 0.1438 0.1565 -3.10 3.50 0.95 ETI
Zs_2_1_alpha_4 -7.3035 -7.2037 -11.24 -3.96 0.95 ETI
Zs_2_1_alpha_5 -0.2824 -0.2975 -6.35 5.84 0.95 ETI
Zs_2_1_alpha_6 -0.0276 -0.0578 -6.03 6.16 0.95 ETI
Zs_2_1_alpha_7 0.3910 0.3588 -7.17 7.97 0.95 ETI
Zs_2_1_alpha_8 -0.1244 -0.1869 -7.86 7.40 0.95 ETI

Visualize model

Besides the figures from the summary function, we also provide a plot function to visualize the estimated smooth functions.

The bmam function calls the built-in function generate_pred to generate the prediction data used to illustrate the estimated smooth functions. The prediction data are generated according to the model and the data used for model fitting. In this example, we use the simulated data simdata$data to fit the model bf(y ~ x3 + s(x1) + s(x2) + (1+x3|id)). The function generates a sequence of length 100 (the default) from the minimum to the maximum value of x1 in simdata$data, and similarly for x2. The sequence length can be changed via the length argument, for example,

bmam(object = model_brms, k = 100, CIType = "ETI", CI = 0.95,
     length = 200)

In addition, users can also provide their own prediction data, for example,

bmam(object = model_brms, k = 100, CIType = "ETI", CI = 0.95,
     preddat = user.preddat)
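For instance, a prediction grid mirroring generate_pred could be built as follows. This is a sketch: the column names x1, x2 and x3 follow the example data, a toy stand-in for dat is simulated so the snippet is self-contained, and holding x3 at 0 is an assumed reference choice.

```r
# Toy stand-in for the model-fitting data (in practice, use simdata$data)
set.seed(4321)
dat <- data.frame(x1 = runif(50, -1, 1),
                  x2 = runif(50, -1, 1),
                  x3 = runif(50, -1, 1))

# Evenly spaced grids over the observed ranges of the smooth covariates,
# mirroring what generate_pred does by default
len <- 200
user.preddat <- data.frame(
  x1 = seq(min(dat$x1), max(dat$x1), length.out = len),
  x2 = seq(min(dat$x2), max(dat$x2), length.out = len),
  x3 = 0  # linear covariate held at an assumed reference value
)
```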

The plot function shows the estimated smooth functions from both the BMAM and the conditional model (Bayesian GAMM). To plot the BMAM only, add the argument conditional = FALSE.

plot(bmam.fit)

We can also pass the true smooth functions simdata$f to the plot function to compare the fitted and true curves,

plot(bmam.fit, smooth.function = simdata$f)

Compare with other models

The compared.model argument of the plot function allows us to supply other fitted models to be compared with the BMAM and the conditional model.

The supported models include

  • Marginal additive models fitted by the mam package
  • Generalized additive models fitted by the mgcv package
  • Bayesian generalized additive models fitted by the brms package

Marginal additive models

The frequentist marginal additive model can be fitted by the mam function in the mam package (McGee and Stringer 2022). The prediction data used in mam should be the same as those used in bmam; the prediction data generated by bmam can be obtained from the fitted object as bmam.fit$Preddat.

library(mam)
library(mgcv)
themam <- mam(smooth = list(s(x1),s(x2)),
                   re = y ~ (1+x3|id),
                   fe = ~ x3,
                   dat = dat,
                   margdat = dat,
                   preddat = bmam.fit$Preddat,
                   control = mam_control(
                     method = 'trust',
                     varmethod = 1,
                     verbose = FALSE,
                     retcond = TRUE))
plot(bmam.fit, compared.model = themam, smooth.function = simdata$f)

Generalized additive models

We can also compare the models with a GAM. The plot function supports generalized additive models fitted by the mgcv package and Bayesian generalized additive models fitted by the brms package.

thegam <- gam(y ~ x3 + s(x1) + s(x2),
              data = dat, family = binomial(), method = "REML")
plot(bmam.fit, compared.model = thegam, smooth.function = simdata$f)

bgam <- brm(bf(y ~  x3 + s(x1) + s(x2)), data = dat, 
            family = "bernoulli", 
            cores = 4, seed = 4321, warmup = 1000, 
            iter = 2000, chains = 4,
            refresh=0, backend = "cmdstanr")
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 2 finished in 14.3 seconds.
#> Chain 4 finished in 15.2 seconds.
#> Chain 1 finished in 16.1 seconds.
#> Chain 3 finished in 16.8 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 15.6 seconds.
#> Total execution time: 17.1 seconds.
plot(bmam.fit, compared.model = bgam, smooth.function = simdata$f)

Center the smooth terms

In some cases, we may want to center the smooth functions.

We can add the argument centered = TRUE to the bmam function to estimate the centered smooth functions.

bmam.fit.centered <- bmam(object = model_brms, k = 100, CIType = "ETI",
                          CI = 0.95, centered = TRUE)

We can also fit a MAM with centered smooth functions,

themam.centered <- mam(smooth = list(s(x1),s(x2)),
                   re = y ~ (1+x3|id),
                   fe = ~ x3,
                   dat = dat,
                   margdat = dat,
                   preddat = bmam.fit.centered$Preddat,
                   control = mam_control(
                     centered = TRUE,
                     method = 'trust',
                     varmethod = 1,
                     verbose = FALSE,
                     retcond = TRUE))

Then, we can use the plot function introduced above to compare the centered smooth terms from the BMAM, the conditional model and the MAM.

plot(bmam.fit.centered, compared.model = themam.centered)
#> BMAM is centered.

References

Bürkner, Paul-Christian. 2017. “brms: An R Package for Bayesian Multilevel Models Using Stan.” Journal of Statistical Software 80: 1–28.
Heagerty, Patrick J, and Scott L Zeger. 2000. “Marginalized Multilevel Models and Likelihood Inference.” Statistical Science 15 (1): 1–26.
Hedeker, Donald, Stephen HC du Toit, Hakan Demirtas, and Robert D Gibbons. 2018. “A Note on Marginalization of Regression Parameters from Mixed Models of Binary Outcomes.” Biometrics 74 (1): 354–61.
McGee, Glen, and Alex Stringer. 2022. “Flexible Marginal Models for Dependent Data.” arXiv Preprint arXiv:2204.07188.
Wood, Simon N. 2006. Generalized Additive Models: An Introduction with R. Chapman & Hall/CRC.